library(tidyverse)
Registered S3 method overwritten by 'dplyr':
method from
print.rowwise_df
Registered S3 methods overwritten by 'dbplyr':
method from
print.tbl_lazy
print.tbl_sql
[30m── [1mAttaching packages[22m ────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──[39m
[30m[32m✓[30m [34mggplot2[30m 3.2.1 [32m✓[30m [34mpurrr [30m 0.3.3
[32m✓[30m [34mtibble [30m 2.1.3 [32m✓[30m [34mdplyr [30m 0.8.4
[32m✓[30m [34mtidyr [30m 1.0.2 [32m✓[30m [34mstringr[30m 1.4.0
[32m✓[30m [34mreadr [30m 1.3.1 [32m✓[30m [34mforcats[30m 0.4.0[39m
[30m── [1mConflicts[22m ───────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
[31mx[30m [34mdplyr[30m::[32mfilter()[30m masks [34mstats[30m::filter()
[31mx[30m [34mdplyr[30m::[32mlag()[30m masks [34mstats[30m::lag()[39m
library(lme4)
Loading required package: Matrix
Attaching package: ‘Matrix’
The following objects are masked from ‘package:tidyr’:
expand, pack, unpack
library(lmerTest)
Attaching package: ‘lmerTest’
The following object is masked from ‘package:lme4’:
lmer
The following object is masked from ‘package:stats’:
step
library(plotrix)
library(stringr)
library(readxl)
library(RColorBrewer)
library(mvtnorm)
library(mgcv)
Loading required package: nlme
Attaching package: ‘nlme’
The following object is masked from ‘package:lme4’:
lmList
The following object is masked from ‘package:dplyr’:
collapse
This is mgcv 1.8-31. For overview type 'help("mgcv-package")'.
# Total held-out log-likelihood of `test_y` under a fitted regression model.
#
# Each observation is scored with a normal density centred on the model's
# prediction for the matching row of `test_X`. The spread is the residual
# standard deviation reported by sigma() for the fitted (training) model.
#   lm:     fitted model with predict() and sigma() methods (lm fits in this
#           script). `re.form = NA` requests fixed-effects-only predictions from
#           lme4 models and is passed through `...` for lm fits.
#   test_X: data frame of predictors for the new data.
#   test_y: observed responses, aligned with the rows of test_X.
# Returns a single number: the sum of the per-observation log densities.
logLik_test <- function(lm, test_X, test_y) {
  fitted_mean <- predict(lm, test_X, re.form = NA)
  resid_sd <- sigma(lm)
  per_obs <- dnorm(test_y, mean = fitted_mean, sd = resid_sd, log = TRUE)
  sum(per_obs)
}
# Per-observation held-out log-likelihoods (the unsummed form of logLik_test).
#
#   lm:     fitted model with predict() and sigma() methods; the residual
#           standard deviation from sigma() (estimated on the training data) is
#           used as the spread of every normal density.
#   test_X: data frame of predictors for the new data.
#   test_y: observed responses, aligned with the rows of test_X.
# Returns a numeric vector with one log density per element of test_y.
logLik_test_per <- function(lm, test_X, test_y) {
  centre <- predict(lm, test_X, re.form = NA)
  spread <- sigma(lm)
  dnorm(test_y, mean = centre, sd = spread, log = TRUE)
}
# Mean squared prediction error of a fitted model on new data.
#   lm:     fitted model with a predict() method.
#   test_X: data frame of predictors for the new data.
#   test_y: observed responses, aligned with the rows of test_X.
# Returns mean((prediction - test_y)^2) as a single number.
mse_test <- function(lm, test_X, test_y) {
  errors <- predict(lm, test_X, re.form = NA) - test_y
  mean(errors^2)
}
#Sanity checks
#mylm <- gam(psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"), data=train_data)
#c(logLik(mylm), logLik_test(mylm, train_data, train_data$psychometric))
#logLik_test(mylm, test_data, test_data$psychometric)
data <- read.csv("../data/harmonized_results.csv")
# Build spillover predictors: for every row, the surprisal, word code, length
# and frequency of the 1, 2 and 3 preceding rows within the same
# corpus--model--training--seed group.
# NOTE(review): lag() follows the existing row order; this assumes rows are
# already sorted by `code` within each group in the CSV -- confirm.
all_data <- data %>%
  mutate(seed = as.factor(seed)) %>%
  group_by(corpus, model, training, seed) %>%
  mutate(prev_surp = lag(surprisal),
         prev_code = lag(code),
         prev_len = lag(len),
         prev_freq = lag(freq),
         prev2_freq = lag(prev_freq),
         prev2_code = lag(prev_code),
         prev2_len = lag(prev_len),
         prev2_surp = lag(prev_surp),
         prev3_freq = lag(prev2_freq),
         prev3_code = lag(prev2_code),
         prev3_len = lag(prev2_len),
         prev3_surp = lag(prev2_surp)) %>%
  ungroup() %>%
  # Keep only rows whose spillover window is contiguous, i.e. whose lagged rows
  # really are the immediately preceding words. The Dundee regressions use
  # three words of spillover, so all three preceding codes must be consecutive;
  # every other corpus uses one word of spillover. (Checking only two back let
  # rows with a non-contiguous 3-back word through.)
  filter((corpus == "dundee" &
            code == prev_code + 1 &
            code == prev2_code + 2 &
            code == prev3_code + 3) |
           (corpus != "dundee" & code == prev_code + 1)) %>%
  select(-prev_code, -prev2_code, -prev3_code) %>%
  drop_na()
# Normalise the model label "gpt-2" to "gpt2" so both spellings end up in a
# single level, then store `model` as a factor again.
all_data <- all_data %>%
  mutate(model = as.character(model)) %>%
  mutate(model = as.factor(if_else(model == "gpt-2", "gpt2", model)))
# Compute linear model stats for the given training data subset and full test
# data. Automatically subsets the test data to match the relevant group for
# which we are training a linear model.
#   df:        training rows for one model--training--seed--corpus group (the
#              slice dplyr::do() passes as `.`).
#   test_data: the whole held-out fold; only rows sharing df's training, model,
#              seed and corpus values are scored.
#   formula:   regression formula handed to lm().
#   store_env: environment in which the fitted lm is saved under the name
#              paste(model, training, seed, corpus), so get_lm_residuals() can
#              retrieve it later.
# Returns the summarise() output with columns log_lik (training
# log-likelihood), test_lik (summed held-out log-likelihood) and test_mse
# (held-out mean squared error).
get_lm_data <- function(df, test_data, formula, store_env) {
  # GAM alternative: this_lm <- gam(formula, data = df)
  this_lm <- lm(formula, data = df)
  # Restrict the held-out fold to the group this model was trained on.
  this_test_data <- semi_join(test_data, df,
                              by = c("training", "model", "seed", "corpus"))
  # Save the lm in store_env (not the global env) so that residuals can be
  # computed later.
  lm_name <- unique(paste(df$model, df$training, df$seed, df$corpus))[1]
  assign(lm_name, this_lm, envir = store_env)
  summarise(df,
            log_lik = as.numeric(logLik(this_lm, REML = FALSE)),
            test_lik = logLik_test(this_lm, this_test_data,
                                   this_test_data$psychometric),
            test_mse = mse_test(this_lm, this_test_data,
                                this_test_data$psychometric))
}
# Score held-out rows with a model previously stored by get_lm_data().
#   df:        held-out rows for one model--training--seed--corpus group.
#   store_env: environment holding fitted models keyed by
#              paste(model, training, seed, corpus).
# Returns df with two extra columns:
#   likelihood: per-row log density of psychometric under the stored model.
#   resid:      observed psychometric minus the model's prediction.
get_lm_residuals <- function(df, store_env) {
  stored_lm <- get(unique(paste(df$model, df$training, df$seed, df$corpus))[1],
                   envir = store_env)
  mutate(df,
         likelihood = logLik_test_per(stored_lm, df, df$psychometric),
         resid = df$psychometric - predict(stored_lm, df, re.form = NA))
}
#####
# Define regression formulae.
# Eye-tracking (Dundee, "rt") regressions include spillover predictors from up
# to three preceding words; SPRT regressions include only the immediately
# preceding word. The "full" formulae add surprisal terms to the
# frequency/length baselines.
# Earlier GAM-based variants, kept for reference:
#baseline_rt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr")
#baseline_sprt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr")
#full_rt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20)
#+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"))
#full_sprt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + s(prev2_surp, bs = "cr", k = 20)
#+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr"))
baseline_rt_regression <- psychometric ~ freq + prev_freq + prev2_freq + prev3_freq +
  len + prev_len + prev2_len + prev3_len
baseline_sprt_regression <- psychometric ~ freq + prev_freq + len + prev_len
full_rt_regression <- psychometric ~ surprisal + prev_surp + prev2_surp + prev3_surp +
  freq + prev_freq + prev2_freq + prev3_freq +
  len + prev_len + prev2_len + prev3_len
full_sprt_regression <- psychometric ~ surprisal + prev_surp +
  freq + prev_freq + len + prev_len
#####
# Prepare frames/environments for storing results/objects.
baseline_results <- data.frame()
full_model_results <- data.frame()
baseline_residuals <- data.frame()
full_residuals <- data.frame()
# Randomly shuffle the rows before assigning folds. The RNG seed is fixed so the
# fold assignment, and therefore every saved checkpoint, is reproducible.
cv_seed <- 1234
set.seed(cv_seed)
all_data <- all_data[sample(nrow(all_data)), ]
# Split the shuffled rows into K (near-)equally sized folds.
K <- 5
folds <- cut(seq_len(nrow(all_data)), breaks = K, labels = FALSE)
# Perform K-fold cross-validation (K = 5) with the per-corpus helpers below.
# Fit and score the baseline (frequency/length only) regression for one
# model--training--seed--corpus group.
#   corpus: the group's corpus label. Only its first element is consulted, so
#           the condition handed to `if` is always length one (a longer
#           condition is an error in R >= 4.2).
#   df, test_data, env: forwarded to get_lm_data() as the training rows, the
#           held-out fold, and the environment that stores the fitted lm.
# The Dundee corpus uses baseline_rt_regression; every other corpus uses
# baseline_sprt_regression.
baseline_corpus <- function(corpus, df, test_data, env) {
  regression <- if (corpus[1] == "dundee") {
    baseline_rt_regression
  } else {
    baseline_sprt_regression
  }
  get_lm_data(df, test_data, regression, env)
}
# Fit and score the full (surprisal + baseline predictors) regression for one
# model--training--seed--corpus group.
#   corpus: the group's corpus label; only the first element is inspected.
#   df, test_data, env: forwarded to get_lm_data() unchanged.
# Dundee gets full_rt_regression; all other corpora get full_sprt_regression.
full_model_corpus <- function(corpus, df, test_data, env) {
  is_dundee <- corpus[1] == "dundee"
  chosen <- if (is_dundee) full_rt_regression else full_sprt_regression
  get_lm_data(df, test_data, chosen, env)
}
for (i in seq_len(K)) {
  # Split the data into held-out fold i and the remaining training folds.
  # Logical indexing keeps train_data correct even if fold i is empty
  # (all_data[-integer(0), ] would select no rows at all).
  is_test <- folds == i
  test_data <- all_data[is_test, ]
  train_data <- all_data[!is_test, ]
  # Fresh, separate environments for this fold's fitted lms. get_lm_data()
  # stores baseline and full models for a group under the same name, so they
  # must not share an environment. environment() evaluated at top level returns
  # the global environment for both, which let every full model overwrite its
  # baseline before the baseline residuals were computed.
  baseline_env <- new.env(parent = emptyenv())
  full_env <- new.env(parent = emptyenv())
  # Compute a baseline linear model for each model--training--seed--RT-corpus combination.
  baselines <- train_data %>%
    group_by(model, training, seed, corpus) %>%
    do(baseline_corpus(unique(.$corpus), ., test_data, baseline_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)
  baseline_results <- rbind(baseline_results, baselines)
  # Compute a full linear model for each model--training--seed--RT-corpus combination.
  full_models <- train_data %>%
    group_by(model, training, seed, corpus) %>%
    do(full_model_corpus(unique(.$corpus), ., test_data, full_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)
  full_model_results <- rbind(full_model_results, full_models)
  # Per-row held-out log-likelihoods and residuals from this fold's models.
  fold_baseline_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
    do(get_lm_residuals(., baseline_env)) %>%
    ungroup()
  baseline_residuals <- rbind(baseline_residuals, fold_baseline_residuals)
  fold_full_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
    do(get_lm_residuals(., full_env)) %>%
    ungroup()
  full_residuals <- rbind(full_residuals, fold_full_residuals)
}
|========================================================================================= | 51% ~2 s remaining
|============================================================================================ | 52% ~2 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~2 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|============================================================================================== | 54% ~2 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|====================================================================================================== | 58% ~1 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~0 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|===================================================================================================================================================================== | 94% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|============================================================================== | 45% ~3 s remaining
|================================================================================= | 46% ~2 s remaining
|====================================================================================== | 49% ~2 s remaining
|============================================================================================ | 52% ~2 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~2 s remaining
|================================================================================================================= | 65% ~2 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|====================================================================================================== | 58% ~1 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~0 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|============================================================================================== | 54% ~2 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|==================================================================================================================================================== | 85% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|====================================================================================================== | 58% ~1 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|==================================================================================================================== | 66% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~0 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|===================================================================================================================================================================== | 94% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|=============================================================================================================================================== | 82% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|====================================================================================================== | 58% ~1 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~0 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
|============================================================================================ | 52% ~2 s remaining
|================================================================================================= | 55% ~2 s remaining
|====================================================================================================== | 58% ~2 s remaining
|============================================================================================================ | 62% ~1 s remaining
|================================================================================================================= | 65% ~1 s remaining
|======================================================================================================================= | 68% ~1 s remaining
|============================================================================================================================ | 71% ~1 s remaining
|=============================================================================================================================== | 72% ~1 s remaining
|================================================================================================================================= | 74% ~1 s remaining
|======================================================================================================================================= | 77% ~1 s remaining
|============================================================================================================================================ | 80% ~1 s remaining
|================================================================================================================================================== | 83% ~1 s remaining
|======================================================================================================================================================= | 86% ~1 s remaining
|============================================================================================================================================================= | 89% ~0 s remaining
|================================================================================================================================================================== | 92% ~0 s remaining
|======================================================================================================================================================================= | 95% ~0 s remaining
|============================================================================================================================================================================= | 98% ~0 s remaining
# Checkpoint the cross-validation outputs so the analysis below can be rerun
# without refitting every model.
write.csv(full_residuals, "../data/analysis_checkpoints/full_residuals.csv")
write.csv(baseline_residuals, "../data/analysis_checkpoints/baseline_residuals.csv")
write.csv(full_model_results, "../data/analysis_checkpoints/full_model_result.csv")
write.csv(baseline_results, "../data/analysis_checkpoints/baseline_results.csv")
# To resume from the checkpoints, reload the files written above (row.names = 1
# consumes the row-name column that write.csv adds):
#full_model_results = read.csv("../data/analysis_checkpoints/full_model_result.csv", row.names = 1)
#baseline_results = read.csv("../data/analysis_checkpoints/baseline_results.csv", row.names = 1)
# Join baseline models with full models and compare performance within-fold.
# The right join keeps every full-model fold; a fold with no matching baseline
# gets NA deltas. Deltas are full minus baseline, so a positive delta_log_lik
# (or a negative delta_mse) means the full model did better on that fold.
model_fold_deltas <- baseline_results %>%
  right_join(full_model_results,
             by = c("model", "training", "seed", "corpus", "fold"),
             suffix = c(".baseline", ".full")) %>%
  mutate(
    seed = as.factor(seed),
    # Plain row-wise differences, so no grouping is required.
    delta_log_lik = test_lik.full - test_lik.baseline,
    delta_mse = test_mse.full - test_mse.baseline
  ) %>%
  select(model, training, seed, corpus, fold, delta_log_lik, delta_mse) %>%
  # Always hand back an ungrouped tibble, whatever class the inputs had.
  as_tibble()
# Now compute across-fold delta statistics for each model--training--seed--corpus:
# the mean per-fold delta and its standard error of the mean (SEM) across folds.
# (The "mean_*" columns are true means, so that mean +/- SEM error bars are
# coherent.) NA deltas, e.g. folds without a matching baseline, propagate to NA
# summaries. The result is ungrouped so later verbs act on the whole table.
model_deltas <- model_fold_deltas %>%
  group_by(model, training, seed, corpus) %>%
  summarise(mean_delta_log_lik = mean(delta_log_lik),
            sem_delta_log_lik = sd(delta_log_lik) / sqrt(n()),
            mean_delta_mse = mean(delta_mse),
            sem_delta_mse = sd(delta_mse) / sqrt(n())) %>%
  ungroup()
# Select which test metric the remaining analyses and plot labels use.
# Both choices are oriented so that higher = full model beats the baseline:
#   "ΔLogLik" -- per-fold test log-likelihood delta (full - baseline)
#   "-ΔMSE"   -- per-fold test MSE delta (full - baseline), negated
metric <- "ΔLogLik"
#metric <- "-ΔMSE"
use_log_lik <- switch(metric,
                      "ΔLogLik" = TRUE,
                      "-ΔMSE" = FALSE,
                      stop("Unsupported metric: ", metric, call. = FALSE))

# Per-fold data: keep only the selected metric, as `delta_test`.
model_fold_deltas <- model_fold_deltas %>%
  mutate(delta_test = if (use_log_lik) delta_log_lik else -delta_mse) %>%
  select(-delta_log_lik, -delta_mse)

# Across-fold summaries: keep only the selected metric's mean and SEM.
# (Negating a variable leaves its SEM unchanged.)
model_deltas <- model_deltas %>%
  mutate(delta_test_mean = if (use_log_lik) mean_delta_log_lik else -mean_delta_mse,
         delta_test_sem = if (use_log_lik) sem_delta_log_lik else sem_delta_mse) %>%
  # Remove the raw metrics.
  select(-mean_delta_log_lik, -sem_delta_log_lik,
         -mean_delta_mse, -sem_delta_mse)
model_deltas
# Sanity check: training on train+test data should yield improved performance over training on just training data. (When evaluating on test data.)
# full_baselines = all_data %>%
# group_by(model, training, seed, corpus) %>%
# summarise(baseline_train_all_test_lik = logLik_test(lm(psychometric ~ len + freq + sent_pos, data=.), semi_join(test_data, ., by=c("training", "model", "seed", "corpus")), semi_join(test_data, ., by=c("training", "model", "seed", "corpus"))$psychometric)) %>%
# ungroup()
# full_baselines
#
# full_baselines %>%
# right_join(baselines, by=c("seed", "training", "model", "corpus")) %>%
# mutate(delta=baseline_train_all_test_lik-baseline_test_lik) %>%
# select(-baseline_lik) # %>%
# #select(-baseline_test_lik, -baseline_train_all_test_lik, -baseline_lik, -baseline_test_mse)
# Per-model metadata: normalise the GPT-2 model name, derive the training-set
# size (million tokens) from the training-set name prefix (the -gptbpe variants
# share the size of their base set), and keep one row per model--training--seed.
language_model_data <- read.csv("../data/model_metadata.csv") %>%
  mutate(
    model = as.factor(if_else(as.character(model) == "gpt-2", "gpt2",
                              as.character(model))),
    train_size = case_when(
      str_starts(training, "bllip-lg") ~ 42,
      str_starts(training, "bllip-md") ~ 15,
      str_starts(training, "bllip-sm") ~ 5,
      str_starts(training, "bllip-xs") ~ 1
    ),
    seed = as.factor(seed)
  ) %>%
  select(-pid, -test_loss) %>%
  distinct(model, training, seed, .keep_all = TRUE)
# Sanity check: number of metadata rows per seed.
table(language_model_data$seed)
0 111 120 922 1111 3602 4301 7245 7877 28066 28068 44862 51272 64924 1581807512 1581807578 1581861474 1581955288
4 7 6 5 4 1 1 1 1 1 1 1 1 1 1 1 1 1
1582126320 1586986276 1587139950
1 1 1
# Sanity check: number of model_deltas rows per seed, for comparison with the
# metadata counts above (seeds absent here have no delta results).
table(model_deltas$seed)
111 120 922 1111 3602 4301 7245 7877 28066 28068 44862 51272 64924 1581807512 1581807578 1581861474 1581955288 1582126320
6 6 6 10 2 2 2 2 2 2 2 2 2 3 3 3 3 3
1586986276 1587139950
2 2
First, join the delta-metric data with the per-model metadata (`language_model_data`).
# Attach per-model metadata (e.g. sg_score, test_ppl, train_size) to both the
# fold-collapsed and the per-fold deltas. The outer merge followed by drop_na()
# keeps only rows that matched on seed/training/model and have no missing value
# in any column.
model_deltas <- drop_na(merge(model_deltas, language_model_data,
                              by = c("seed", "training", "model"), all = TRUE))
model_fold_deltas <- drop_na(merge(model_fold_deltas, language_model_data,
                                   by = c("seed", "training", "model"), all = TRUE))
model_deltas
The per-fold deltas (`model_fold_deltas`) are joined with the metadata too, rather than only the fold-collapsed delta metrics. This supports later regressions that don’t collapse across folds.
# Exclude ordered-neurons from all analyses, at both the fold-collapsed and the
# per-fold level.
model_deltas <- filter(model_deltas, model != "ordered-neurons")
model_fold_deltas <- filter(model_fold_deltas, model != "ordered-neurons")
# SG score vs. the selected predictive-power metric, one panel per corpus.
# Error bars are +/-1 SEM of the delta across folds; the line is an OLS fit with
# its confidence band.
model_deltas %>%
  ggplot(aes(x = sg_score, y = delta_test_mean)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem)) +
  geom_smooth(method = "lm", se = TRUE) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3,
             aes(color = training, shape = model)) +
  ylab(metric) +
  xlab("Syntax Generalization Score") +
  ggtitle("Syntactic Generalization vs. Predictive Power") +
  # One distinct colour per training set; the -gptbpe variants use a 4-step
  # grey ramp so that -sm and -xs can be told apart.
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "bllip-lg-gptbpe" = "#666666",
                                "bllip-md-gptbpe" = "#888888",
                                "bllip-sm-gptbpe" = "#AAAAAA",
                                "bllip-xs-gptbpe" = "#CCCCCC")) +
  facet_grid(~corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
#ggsave("./cogsci_images/sg_loglik.png",height=5,width=6)
We control for the effect of perplexity by regressing both the predictive-power metric and the SG score on test perplexity, and then relating the two sets of residuals to each other.
# Residualize both the per-fold predictive-power delta and the SG score on test
# perplexity, then summarise the residuals per model--training--corpus--seed,
# so the two can be related with the effect of perplexity partialled out.
d_resid = model_fold_deltas %>%
drop_na() %>%
# Residualize the delta metric w.r.t. PPL with one regression per corpus
# (a separate PPL slope per training set via training:test_ppl).
group_by(corpus) %>%
mutate(resid.delta = resid(lm(delta_test ~ training:test_ppl))) %>%
ungroup() %>%
# Residualize SG score w.r.t. PPL with one regression per training set.
group_by(training) %>%
# NB no need for training:ppl interaction, since we're within-group.
# NB sg_score and test_ppl come from the per-model metadata, so every row of a
# given model--training--seed (one per corpus and fold) gets the same resid.sg.
mutate(resid.sg = resid(lm(sg_score ~ test_ppl))) %>%
ungroup() %>%
# Compute mean and SEM across folds per model--training--corpus--seed.
# Because resid.sg is constant within each group, resid.sg.sem is expected to
# be 0 (NA for single-row groups). The result stays grouped by
# model, training, corpus (summarise() drops only the last grouping level).
group_by(model, training, corpus, seed) %>%
summarise(resid.delta.mean = mean(resid.delta),
resid.delta.sem = sd(resid.delta) / sqrt(length(resid.delta)),
resid.sg.mean = mean(resid.sg),
resid.sg.sem = sd(resid.sg) / sqrt(length(resid.sg)))
# Now plot residual vs SG: residual predictive power against residual SG score
# (both residualized on test perplexity), one panel per corpus.
# Vertical error bars are +/-1 SEM of the residual delta across folds. No
# horizontal (SG) error bars are drawn: geom_errorbar() ignores xmin/xmax in
# this ggplot2 version, and resid.sg is constant across a point's folds, so its
# SEM is 0.
d_resid %>%
  #filter(corpus != "bnc-brown") %>%
  ggplot(aes(x = resid.sg.mean, y = resid.delta.mean)) +
  geom_errorbar(aes(ymin = resid.delta.mean - resid.delta.sem,
                    ymax = resid.delta.mean + resid.delta.sem), alpha = 0.3) +
  geom_smooth(method = "lm", se = TRUE) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 4,
             aes(shape = model, color = training)) +
  ylab(paste("Residual", metric)) +
  xlab("Residual Syntax Generalization Score") +
  ggtitle("Syntactic Generalization vs. Predictive Power") +
  # One distinct colour per training set; the -gptbpe variants use a 4-step
  # grey ramp so that -sm and -xs can be told apart.
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "bllip-lg-gptbpe" = "#666666",
                                "bllip-md-gptbpe" = "#888888",
                                "bllip-sm-gptbpe" = "#AAAAAA",
                                "bllip-xs-gptbpe" = "#CCCCCC")) +
  facet_grid(. ~ corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "right")
Ignoring unknown aesthetics: xmin, xmax
# Save the residual-vs-SG figure drawn above (ggsave() defaults to the last plot).
ggsave("../images/cuny2020/ppl_sg.png",height=4.5,width=11)
# Test whether SG score explains variance in the per-fold predictive-power delta
# beyond test perplexity, for a single corpus.
#
# Fits two nested OLS models on that corpus's per-fold rows:
#   lm1: delta_test ~ training:test_ppl             (a PPL slope per training set)
#   lm2: delta_test ~ training:test_ppl + sg_score  (adds SG score)
# prints the F-test comparing them, and returns summary(lm2).
#
# Args:
#   cur_corpus: corpus to analyse, matched against the `corpus` column.
#   data: per-fold data with columns corpus, delta_test, training, test_ppl and
#     sg_score. Defaults to the global `model_fold_deltas` (evaluated lazily,
#     at call time), so existing one-argument calls behave as before.
# Returns: the `summary.lm` object for lm2.
do_stepwise_regression <- function(cur_corpus, data = model_fold_deltas) {
  # which() drops NA comparisons, matching dplyr::filter() semantics.
  regression_data <- data[which(data$corpus == cur_corpus), , drop = FALSE]
  if (nrow(regression_data) == 0) {
    stop("No rows found for corpus '", cur_corpus, "'.", call. = FALSE)
  }
  # NB we're incorporating variance across folds into this regression, good!
  print("----------------------")
  print(cur_corpus)
  lm1 <- lm(delta_test ~ training:test_ppl, data = regression_data)
  lm2 <- lm(delta_test ~ training:test_ppl + sg_score, data = regression_data)
  print(anova(lm1, lm2))
  summary(lm2)
}
# Run the nested SG-vs-perplexity model comparison on the Dundee corpus
# (the bnc-brown run is currently disabled).
#do_stepwise_regression("bnc-brown")
do_stepwise_regression("dundee")
[1] "----------------------"
[1] "dundee"
Analysis of Variance Table
Model 1: delta_test ~ training:test_ppl
Model 2: delta_test ~ training:test_ppl + sg_score
Res.Df RSS Df Sum of Sq F Pr(>F)
1 136 25890
2 135 24555 1 1334.2 7.335 0.007638 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Call:
lm(formula = delta_test ~ training:test_ppl + sg_score, data = regression_data)
Residuals:
Min 1Q Median 3Q Max
-55.316 -4.612 -0.492 4.296 48.896
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 30.46868 7.63911 3.989 0.000108 ***
sg_score -32.47411 11.99048 -2.708 0.007638 **
trainingbllip-lg:test_ppl 0.22829 0.06897 3.310 0.001197 **
trainingbllip-lg-gptbpe:test_ppl 6.73457 0.23468 28.696 < 2e-16 ***
trainingbllip-md:test_ppl 0.12484 0.05902 2.115 0.036261 *
trainingbllip-md-gptbpe:test_ppl 5.72687 0.16970 33.747 < 2e-16 ***
trainingbllip-sm:test_ppl 0.01160 0.05160 0.225 0.822483
trainingbllip-sm-gptbpe:test_ppl 1.33549 0.04429 30.152 < 2e-16 ***
trainingbllip-xs:test_ppl -0.03110 0.03482 -0.893 0.373345
trainingbllip-xs-gptbpe:test_ppl 0.44107 0.01503 29.349 < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 13.49 on 135 degrees of freedom
Multiple R-squared: 0.9754, Adjusted R-squared: 0.9738
F-statistic: 595.6 on 9 and 135 DF, p-value: < 2.2e-16
# Same nested-model comparison on the Natural Stories corpus.
do_stepwise_regression("natural-stories")
[1] "----------------------"
[1] "natural-stories"
Analysis of Variance Table
Model 1: delta_test ~ training:test_ppl
Model 2: delta_test ~ training:test_ppl + sg_score
Res.Df RSS Df Sum of Sq F Pr(>F)
1 136 2032.6
2 135 2019.8 1 12.791 0.8549 0.3568
Call:
lm(formula = delta_test ~ training:test_ppl + sg_score, data = regression_data)
Residuals:
Min 1Q Median 3Q Max
-12.0616 -2.4161 0.2438 2.6515 10.6305
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 12.345412 2.190932 5.635 9.81e-08 ***
sg_score -3.179657 3.438928 -0.925 0.356820
trainingbllip-lg:test_ppl -0.002126 0.019782 -0.107 0.914588
trainingbllip-lg-gptbpe:test_ppl -0.252006 0.067308 -3.744 0.000267 ***
trainingbllip-md:test_ppl -0.024897 0.016928 -1.471 0.143683
trainingbllip-md-gptbpe:test_ppl -0.239400 0.048671 -4.919 2.49e-06 ***
trainingbllip-sm:test_ppl -0.047759 0.014798 -3.227 0.001568 **
trainingbllip-sm-gptbpe:test_ppl -0.061851 0.012703 -4.869 3.09e-06 ***
trainingbllip-xs:test_ppl -0.044518 0.009985 -4.458 1.72e-05 ***
trainingbllip-xs-gptbpe:test_ppl -0.020918 0.004310 -4.853 3.31e-06 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 3.868 on 135 degrees of freedom
Multiple R-squared: 0.445, Adjusted R-squared: 0.4081
F-statistic: 12.03 on 9 and 135 DF, p-value: 8.343e-14
# Test perplexity vs. the selected predictive-power metric, one panel per
# corpus, with +/-1 SEM (across folds) error bars; the figure is then saved.
model_deltas %>%
  ggplot(aes(x = test_ppl, y = delta_test_mean, color = training)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem), alpha = 0.4) +
  #geom_smooth(method="lm", se=F) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 4,
             aes(shape = model)) +
  ylab(metric) +
  xlab("Test Perplexity") +
  #coord_cartesian(ylim = c(1, 16)) +
  ggtitle("Test Perplexity vs. Predictive Power") +
  # One distinct colour per training set; the -gptbpe variants use a 4-step
  # grey ramp so that -sm and -xs can be told apart.
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "bllip-lg-gptbpe" = "#666666",
                                "bllip-md-gptbpe" = "#888888",
                                "bllip-sm-gptbpe" = "#AAAAAA",
                                "bllip-xs-gptbpe" = "#CCCCCC")) +
  facet_grid(~corpus, scales = "free") +
  #coord_cartesian(ylim = c(0, 150)) +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.position = "right")
ggsave("../images/cuny2020/ppl_loglik.png", height = 4.5, width = 11)
# Per model and corpus: Pearson correlation between the across-fold mean delta
# and test perplexity, with its p-value. The htest components are extracted by
# name rather than by position. NB cor.test() errors on groups with fewer than
# 3 complete pairs.
model_deltas %>%
  #filter(model != "5gram", training != "bllip-lg") %>%
  group_by(model, corpus) %>%
  #summarise(n = n())
  summarise(corr = unname(cor.test(delta_test_mean, test_ppl)$estimate),
            pval = cor.test(delta_test_mean, test_ppl)$p.value)
# Log training-set size vs. the selected predictive-power metric, one panel per
# corpus x model, with a per-model OLS fit and +/-1 SEM (across folds) error bars.
model_deltas %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x = train_size, y = delta_test_mean, color = model)) +
  geom_errorbar(width = 0.1,
                aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem)) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.5) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab(metric) +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. Predictive Power") +
  facet_grid(corpus ~ model, scales = "free") +
  #scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
#ggsave("./cogsci_images/training_loglik.png",height=5,width=6)
# Log training-set size vs. SG score, one panel per model, with an OLS fit.
# sg_score is a per-model--training--seed property, but model_deltas repeats it
# once per corpus. Keep one row per model--training--seed so each LM is plotted
# and fitted once; duplicated rows would overstate the precision of the fit.
model_deltas %>%
  distinct(model, training, seed, .keep_all = TRUE) %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x = train_size, y = sg_score, color = model)) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.5) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab("SG Score") +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. Syntactic Generalization") +
  #scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  facet_grid(~model, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
#ggsave("./cogsci_images/training_sg.png",height=5,width=6)
# Word surprisal vs. processing time, restricted to 0 < surprisal < 15 bits,
# with a smoothed fit per training set; one panel per corpus x model. The
# figure is then saved.
all_data %>%
  filter(surprisal < 15, surprisal > 0) %>%
  ggplot(aes(x = surprisal, y = psychometric, color = training)) +
  stat_smooth(se = TRUE, alpha = 0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  #geom_point(stat="identity", position="dodge", alpha=1, size=3) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration") +
  facet_grid(corpus ~ model, scales = "free") +
  # One distinct colour per training set; the -gptbpe variants use a 4-step
  # grey ramp so that -sm and -xs can be told apart.
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "bllip-lg-gptbpe" = "#666666",
                                "bllip-md-gptbpe" = "#888888",
                                "bllip-sm-gptbpe" = "#AAAAAA",
                                "bllip-xs-gptbpe" = "#CCCCCC")) +
  theme(axis.text = element_text(size = 14),
        axis.text.y = element_text(size = 10),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "right")
ggsave("../images/cuny2020/surp_corr.png", height = 4.5, width = 12)